In [ ]:
import torch
import os
import torchvision
from torchvision.transforms import v2
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torch.nn as nn
from PIL import Image, ImageOps
import torch
import pdb
import numpy as np
import yaml
from tqdm import tqdm
import sys
import matplotlib.pyplot as plt
sys.path.append('..')
from template import utils
from torchvision.utils import save_image

import torch.nn.functional as F
In [ ]:
import warnings
warnings.filterwarnings("ignore")
In [ ]:
# setting config
config = yaml.safe_load(open("config.yaml"))
batch_size = int(config["BATCH_SIZE"])

print(f"Our config: {config}")
Our config: {'BATCH_SIZE': 64, 'NUM_EPOCHS': 10, 'LR': '3e-4'}
In [ ]:
transform = transforms.Compose([
    transforms.ToTensor(),
    v2.Resize((128, 128))
])
In [ ]:
train_dataset = torchvision.datasets.CelebA(root='./data', split='train',
                                        download=True, transform=transform)
valid_dataset = torchvision.datasets.CelebA(root='./data', split='valid',
                                       download=True, transform=transform)
test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                       download=True, transform=transform)

#create dataloaders
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
validloader = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified

Display images from the dataset¶

In [ ]:
imgs, labels = next(iter(testloader))
print(f"Image Shapes: {imgs.shape}")
print(f"Label Shapes: {labels.shape}")
Image Shapes: torch.Size([64, 3, 128, 128])
Label Shapes: torch.Size([64, 40])
In [ ]:
N_IMGS = 8
fig, ax = plt.subplots(1,N_IMGS)
fig.set_size_inches(3 * N_IMGS, 3)

ids = np.random.randint(low=0, high=len(train_dataset), size=N_IMGS)

for i, n in enumerate(ids):
    img = train_dataset[n][0].numpy().reshape(3,128,128).transpose(1, 2, 0)
    ax[i].imshow(img)
    #ax[i].set_title(f"Img #{n}  Label: {train_dataset[n][1]}")
    #ax[i].axis("off")
plt.show()
In [ ]:
def save_model(model, optimizer, epoch, stats, exp_no = 40120242):
    """ Saving model checkpoint """
    
    if(not os.path.exists("experiments/experiment_"+str(exp_no)+"/models")):
        os.makedirs("experiments/experiment_"+str(exp_no)+"/models")
    savepath = "experiments/experiment_"+str(exp_no)+f"/models/checkpoint_epoch_{epoch}.pth"

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)
    
    return


def load_model(model, optimizer, savepath):
    """ Loading pretrained checkpoint """
    
    checkpoint = torch.load(savepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]
    
    return model, optimizer, epoch, stats
In [ ]:
def train_epoch(model, train_loader, optimizer, criterion, epoch, device, lambda_kld = 1e-03):
    """ Training a model for one epoch """
    
    loss_list = []
    recons_loss = []
    vae_loss = []
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, _) in progress_bar:
        images = images.to(device)
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
         
        # Forward pass
        recons, (z, mu, log_var) = model(images)
         
        # Calculate Loss
        loss, (mse, kld) = criterion(recons, images, mu, log_var, lambda_kld)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        vae_loss.append(kld.item())
        
        # Getting gradients w.r.t. parameters
        loss.backward()
         
        # Updating parameters
        optimizer.step()
        
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")
        
    mean_loss = np.mean(loss_list)
    
    return mean_loss, loss_list


@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", writer=None, lambda_kld = 1e-03):
    """ Evaluating the model for either validation or test """
    loss_list = []
    recons_loss = []
    kld_loss = []
    
    for i, (images, _) in enumerate(eval_loader):
        images = images.to(device)
        
        # Forward pass 
        recons, (z, mu, log_var) = model(images)
                 
        loss, (mse, kld) = criterion(recons, images, mu, log_var, lambda_kld)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        kld_loss.append(kld.item())
        
        if(i==0 and savefig):
            save_image(recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png"))
            
    # Total correct predictions and loss
    loss = np.mean(loss_list)
    recons_loss = np.mean(recons_loss)
    kld_loss = np.mean(kld_loss)
    return loss, recons_loss, kld_loss


def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader,
                num_epochs, writer,savepath="", save_frequency=2,lambda_kld = 1e-03):
    """ Training a model for a given number of epochs"""
    
    train_loss = []
    val_loss =  []
    val_loss_recons =  []
    val_loss_kld =  []
    loss_iters = []
    
    for epoch in range(num_epochs):
           
        # validation epoch
        model.eval()  # important for dropout and batch norms
        log_epoch = (epoch % save_frequency == 0 or epoch == num_epochs - 1)
        loss, recons_loss, kld_loss = eval_model(
                model=model, eval_loader=valid_loader, criterion=criterion,
                device=device, epoch=epoch, savefig=log_epoch, savepath=savepath,
                writer=writer, lambda_kld = lambda_kld
            )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)
        val_loss_kld.append(kld_loss)

        
        # training epoch
        model.train()  # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
                model=model, train_loader=train_loader, optimizer=optimizer,
                criterion=criterion, epoch=epoch, device=device, lambda_kld = lambda_kld 
            )
        
        # PLATEAU SCHEDULER
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        loss_iters = loss_iters + cur_loss_iters
        
        if(epoch % save_frequency == 0):
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats)
        
        if(log_epoch):
            print(f"    Train loss: {round(mean_loss, 5)}")
            print(f"    Valid loss: {round(loss, 5)}")
            print(f"       Valid loss recons: {round(val_loss_recons[-1], 5)}")
            print(f"       Valid loss KL-D:   {round(val_loss_kld[-1], 5)}")
    
    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld

In the following we will outline some obvservations made when choosing the architecture of our ConvVAE. When employing a kernel size of 2, the images exhibited a blocky appearance, prompting a switch to a kernel size of 3 with a smaller stride, resulting in more satisfactory outcomes by mitigating the blockiness. Attempts to enhance the model by increasing the number of linear layers in the encoder led to undesired brownish artifacts. Addressing this issue with supplementary convolutional operations at the end of the network helped alleviate the problem. The introduction of dropouts along with additional convolutional layers resulted in images having a brownish tint, despite an increase in details.

In [ ]:
class ConvVAE(nn.Module):
    
    def __init__(self):
        super(ConvVAE, self).__init__()

        # Define Convolutional Encoders
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels = 8, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2),
            nn.Conv2d(in_channels = 8, out_channels = 16, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 3, padding = 1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2),
            nn.Flatten(),
        )

        # Define mean and variance
        self.mu = nn.Linear(8192, 200)
        
        # Note: we learn the log variance to make training easier (allows negative values)
        self.log_var = nn.Linear(8192, 200)

        # Define decoder
        self.decoder = nn.Sequential(
            nn.Linear(200, 8192),
            nn.ReLU(),
            nn.Unflatten(dim = 1, unflattened_size=(32, 16, 16)),
            nn.ConvTranspose2d(in_channels = 32, out_channels=32, kernel_size = 3, stride = 2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels = 16, kernel_size = 3, stride = 2, padding = 0),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 16, kernel_size = 3, stride = 2, padding = 0),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=16, kernel_size=8, stride = 1, padding = 0),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=3, kernel_size=3, stride = 1, padding = 1),
        )

    def reparameterize(self, mu, log_var):
        """ Reparametrization trick"""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)  # random sampling happens here
        z = mu + std * eps
        return z

    def forward(self, x):
        x = self.encoder(x)
        mean = self.mu(x)
        log_var = self.log_var(x)

        z = self.reparameterize(mean, log_var)
        x_hat = self.decoder(z)
        return x_hat, (z, mean, log_var)
In [ ]:
def vae_loss_function(recons, target, mu, log_var, lambda_kld=1e-3):
    """
    Combined loss function for joint optimization of 
    reconstruction and ELBO
    """
    recons_loss = F.mse_loss(recons, target)
    # Deriving kld for vaes: https://stats.stackexchange.com/questions/318748/deriving-the-kl-divergence-loss-for-vaes
    kld = (-0.5 * (1 + log_var - mu**2 - log_var.exp()).sum(dim=1)).mean(dim=0)  # closed-form solution of KLD in Gaussian
    loss = recons_loss + lambda_kld * kld

    return loss, (recons_loss, kld)
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cvae = ConvVAE()
criterion = vae_loss_function
optimizer = torch.optim.Adam(cvae.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3, factor = 0.5, verbose = True)
cvae = cvae.to(device)
In [ ]:
savepath = "/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_9"
In [ ]:
'''train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
        model=cvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
        train_loader=trainloader, valid_loader=validloader, num_epochs=15, savepath=savepath, writer=None)'''
Out[ ]:
'train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(\n        model=cvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,\n        train_loader=trainloader, valid_loader=validloader, num_epochs=15, savepath=savepath, writer=None)'
In [ ]:
cvae = ConvVAE().to(device)
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
In [ ]:
utils.plot_loss_epoch(stats['train_loss'][:48], stats['valid_loss'][1:])
recons_loss = stats['other_loss_stats'][0]
Kl_d_loss = stats['other_loss_stats'][1]

epochs = range(1, len(Kl_d_loss) + 1)

plt.figure(figsize=(10, 5))
plt.plot(epochs, Kl_d_loss, 'bo-', label='KLD Loss')
plt.title('KLD Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(epochs, recons_loss, 'bo-', label='Reconstruction Loss')
plt.title('Reconstruction Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

Most of the loss development occurs in the earlier steps. We note that we omitted the first step in this plot to make the later changes in loss visible in the figure. Later on we only see minimal improvement. This is also reflected by the quality of the reconstruction images of different epochs. While the main changes (i.e. face becomes visible, blockiness decrease) occurs in earlier epochs. Later epochs are mostly making the reconstructions more detailed. The KL-Divergence converges quickly to a value around 10.5. As expected, given the larger weighting the Reconstruction loss looks very similar to the train/validation loss.

In [ ]:
# Generate more images
with torch.no_grad():
    for i in range(5):
        z = torch.randn(64, 200).to(device)
        sample = cvae.decoder(z)
    

recons = sample.view(64, 3, 128, 128).cpu()
In [ ]:
fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed

for i in range(10):
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].axis('off')  # Turn off axis labels for clarity

plt.tight_layout()
plt.show()
In [ ]:
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)

with torch.no_grad():
    sample, _ = cvae(test_data)

    recons = sample.view(batch_size, 3, 128, 128).cpu()

    fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed

    for i in range(10):
        img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
        # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
        test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
        axes[0][i].imshow(test_img)
        axes[0][i].axis('off')
        axes[1][i].imshow(img)
        axes[1][i].axis('off')  # Turn off axis labels for clarity

    plt.tight_layout()
    plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

The reconstruction results of the standard ConvVAE are decent. The faces are clearly visible and some definining features of the reconstrucions are visible as well. One downside is, that the faces look very similar like an "average" of all faces. This is expected however and one of the reasons VQVAE and resulting architectures gained more popularity. Similar comments can be made about the generation results. We do believe that improving our architecture even further would lead to even better results.

Implement a ConvVAE using a pretrained encoder (e.g. ResNet18, ConvNeXt-Tiny, ...)¶

We decided to use ResNet18 pretrained encoder. The class below represents a basic block, which resnet 18 consists of. One important difference between this basic block and the one used in original resnet, is that here, we use nn.ConvTranspose2d instead nn.Conv2d. This is because we use it in the decoder, which is supposed to be a mirrored version of an encoder. Generally it is true, architecture is the same in the encoder and the decoder except number of filters in the decoder layers. The number(s) needed to be changed because, without doing so, final reconstructed image dimensions would be different than expected (3, 128, 128). Of course, this issue can be fixed by resizing the final image, but this created weird artifacts in the reconstructed image. For finetuning we chose end-to-end finetuning, given its a simple choice, guaranteeing decent results.

In [ ]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=2, transpose=False):
        super(BasicBlock, self).__init__()
        self.transpose = transpose
        output_padding = 1 if stride > 1 else 0
        self.conv1 = nn.ConvTranspose2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, \
                                        output_padding=output_padding, bias=False)
        self.conv2 = nn.ConvTranspose2d(planes, planes, kernel_size=3, stride=1, padding=1, \
                                        output_padding=0, bias=False)

        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, output_padding=output_padding, bias=False),
              #  nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x))) #self.bn1()
        out = self.bn2(self.conv2(out)) #self.bn2()
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class Decoder(nn.Module):
    def __init__(self, block, num_blocks, latent_dim):
        super(Decoder, self).__init__()
        self.in_planes = 64*block.expansion
        self.fc2 = nn.Linear(latent_dim, 4096)
        
        self.layer1 = self._make_layer(block, 32, num_blocks[3], stride=2, transpose=True) #filter numbers are decreased but architecture remains the same
        self.layer2 = self._make_layer(block, 16, num_blocks[2], stride=2, transpose=True)
        self.layer3 = self._make_layer(block, 8, num_blocks[1], stride=2, transpose=True)
        self.layer4 = self._make_layer(block, 8, num_blocks[0], stride= 1, transpose=True)
        self.conv1 = nn.ConvTranspose2d(8, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(3)

    def _make_layer(self, block, planes, num_blocks, stride, transpose):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in reversed(strides):
            layers.append(block(self.in_planes, planes, stride, transpose))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.fc2(x))
        out = out.view(out.size(0),64,8,8)  # reshape output of linear layer
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.relu(self.bn1(self.conv1(out))) 
        return out

def DecoderResNet18(latent_dim):
    return Decoder(BasicBlock, [2,2,2,2], latent_dim)
In [ ]:
class ResNetVAE(nn.Module):
    def __init__(self, latent_dim = 512):
        super(ResNetVAE, self).__init__()
        
        resnet = torchvision.models.resnet18(weights='DEFAULT')
        self.encoder = resnet
        
        # Note: 1000 marks the output dim of resnet
        self.mu = nn.Linear(1000, latent_dim)
        
        # Note: we learn the log variance to make training easier (allows negative values)
        self.log_var = nn.Linear(1000, latent_dim)
        
        self.decoder = DecoderResNet18(latent_dim)

        
    def reparameterize(self, mu, log_var):
        """ Reparametrization trick"""
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)  # random sampling happens here
        z = mu + std * eps
        return z

    def forward(self, x):
        x = self.encoder(x)
        mu = self.mu(x)
        log_var = self.log_var(x)
        z = self.reparameterize(mu, log_var)
        x_hat = self.decoder(z)
        return x_hat, (z, mu, log_var)
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnetvae = ResNetVAE(512)
criterion = vae_loss_function
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 4, factor = 0.5, verbose = True)
resnetvae = resnetvae.to(device)
In [ ]:
num_epochs = 50
In [ ]:
device
Out[ ]:
'cuda'
In [ ]:
train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
        model=resnetvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
        train_loader=trainloader, valid_loader=validloader, num_epochs=num_epochs, writer=None,lambda_kld=1e-4)
Epoch 1 Iter 2544: loss 0.03966. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [02:39<00:00, 15.94it/s]
    Train loss: 0.07908
    Valid loss: 0.57485
       Valid loss recons: 0.28432
       Valid loss KL-D:   2905.30506
Epoch 2 Iter 2544: loss 0.02919. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.56it/s]
Epoch 3 Iter 2544: loss 0.02737. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s]
    Train loss: 0.02353
    Valid loss: 0.02543
       Valid loss recons: 0.0209
       Valid loss KL-D:   45.28074
Epoch 4 Iter 2544: loss 0.02349. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.53it/s]
Epoch 5 Iter 2544: loss 0.02303. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.02179
    Valid loss: 0.02297
       Valid loss recons: 0.01813
       Valid loss KL-D:   48.41507
Epoch 6 Iter 2544: loss 0.02189. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Epoch 7 Iter 2544: loss 0.02221. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s]
    Train loss: 0.02095
    Valid loss: 0.02095
       Valid loss recons: 0.01612
       Valid loss KL-D:   48.29697
Epoch 8 Iter 2544: loss 0.02217. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
Epoch 9 Iter 2544: loss 0.02239. : 100%|█████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
    Train loss: 0.02037
    Valid loss: 0.0205
       Valid loss recons: 0.01553
       Valid loss KL-D:   49.74597
Epoch 10 Iter 2544: loss 0.01970. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
Epoch 11 Iter 2544: loss 0.01845. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
    Train loss: 0.02
    Valid loss: 0.01985
       Valid loss recons: 0.01494
       Valid loss KL-D:   49.08961
Epoch 12 Iter 2544: loss 0.01988. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s]
Epoch 13 Iter 2544: loss 0.02589. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
    Train loss: 0.02088
    Valid loss: 0.0196
       Valid loss recons: 0.01457
       Valid loss KL-D:   50.2956
Epoch 14 Iter 2544: loss 0.01909. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s]
Epoch 15 Iter 2544: loss 0.02095. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
    Train loss: 0.01976
    Valid loss: 0.01961
       Valid loss recons: 0.01432
       Valid loss KL-D:   52.85515
Epoch 16 Iter 2544: loss 0.01989. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.46it/s]
Epoch 17 Iter 2544: loss 0.02059. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01954
    Valid loss: 0.01948
       Valid loss recons: 0.01435
       Valid loss KL-D:   51.30299
Epoch 18 Iter 2544: loss 0.02027. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
Epoch 19 Iter 2544: loss 0.02018. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
    Train loss: 0.01958
    Valid loss: 0.01942
       Valid loss recons: 0.01421
       Valid loss KL-D:   52.10604
Epoch 20 Iter 2544: loss 0.01931. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Epoch 21 Iter 2544: loss 0.02264. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
    Train loss: 0.01934
    Valid loss: 0.01925
       Valid loss recons: 0.01401
       Valid loss KL-D:   52.39189
Epoch 22 Iter 2544: loss 0.02076. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.54it/s]
Epoch 23 Iter 2544: loss 0.01886. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01924
    Valid loss: 0.01923
       Valid loss recons: 0.01397
       Valid loss KL-D:   52.6373
Epoch 24 Iter 2544: loss 0.01902. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s]
Epoch 25 Iter 2544: loss 0.01988. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s]
    Train loss: 0.01916
    Valid loss: 0.01934
       Valid loss recons: 0.01399
       Valid loss KL-D:   53.5481
Epoch 26 Iter 2544: loss 0.02002. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s]
Epoch 27 Iter 2544: loss 0.02156. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.27it/s]
    Train loss: 0.01908
    Valid loss: 0.019
       Valid loss recons: 0.01366
       Valid loss KL-D:   53.37492
Epoch 28 Iter 2544: loss 0.01837. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
Epoch 29 Iter 2544: loss 0.02175. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01902
    Valid loss: 0.01911
       Valid loss recons: 0.01375
       Valid loss KL-D:   53.59315
Epoch 30 Iter 2544: loss 0.01916. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
Epoch 31 Iter 2544: loss 0.02035. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.51it/s]
    Train loss: 0.01897
    Valid loss: 0.01894
       Valid loss recons: 0.01351
       Valid loss KL-D:   54.28299
Epoch 32 Iter 2544: loss 0.02010. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.57it/s]
Epoch 33 Iter 2544: loss 0.02093. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01891
    Valid loss: 0.01885
       Valid loss recons: 0.01343
       Valid loss KL-D:   54.13779
Epoch 34 Iter 2544: loss 0.01982. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.44it/s]
Epoch 35 Iter 2544: loss 0.02034. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
    Train loss: 0.01885
    Valid loss: 0.01887
       Valid loss recons: 0.01341
       Valid loss KL-D:   54.5866
Epoch 36 Iter 2544: loss 0.02112. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
Epoch 37 Iter 2544: loss 0.01945. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
    Train loss: 0.01881
    Valid loss: 0.01879
       Valid loss recons: 0.0135
       Valid loss KL-D:   52.97335
Epoch 38 Iter 2544: loss 0.01626. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
Epoch 39 Iter 2544: loss 0.01797. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.46it/s]
    Train loss: 0.01878
    Valid loss: 0.01879
       Valid loss recons: 0.0134
       Valid loss KL-D:   53.91221
Epoch 40 Iter 2544: loss 0.02491. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.50it/s]
Epoch 41 Iter 2544: loss 0.01838. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s]
    Train loss: 0.01874
    Valid loss: 0.01876
       Valid loss recons: 0.01336
       Valid loss KL-D:   54.01153
Epoch 42 Iter 2544: loss 0.02161. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
Epoch 43 Iter 2544: loss 0.01922. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s]
    Train loss: 0.01885
    Valid loss: 0.01864
       Valid loss recons: 0.01319
       Valid loss KL-D:   54.45878
Epoch 44 Iter 2544: loss 0.01805. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.53it/s]
Epoch 45 Iter 2544: loss 0.01930. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01869
    Valid loss: 0.01871
       Valid loss recons: 0.01325
       Valid loss KL-D:   54.59709
Epoch 46 Iter 2544: loss 0.01862. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.55it/s]
Epoch 47 Iter 2544: loss 0.01874. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.01868
    Valid loss: 0.01912
       Valid loss recons: 0.01363
       Valid loss KL-D:   54.88449
Epoch 48 Iter 2544: loss 0.02108. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s]
Epoch 49 Iter 2544: loss 0.01891. : 100%|████████████████████████████████████████████████████████| 2544/2544 [01:46<00:00, 23.86it/s]
Epoch 00049: reducing learning rate of group 0 to 5.0000e-04.
    Train loss: 0.01864
    Valid loss: 0.01892
       Valid loss recons: 0.01329
       Valid loss KL-D:   56.22169
Epoch 50 Iter 2544: loss 0.02033. : 100%|████████████████████████████████████████████████████████| 2544/2544 [08:38<00:00,  4.91it/s]
    Train loss: 0.01845
    Valid loss: 0.01856
       Valid loss recons: 0.01306
       Valid loss KL-D:   54.98609
Training completed

In [ ]:
utils.plot_loss_epoch(train_loss, val_loss)
In [ ]:
resnetvae = ResNetVAE(512)
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
_, _, _, stats = load_model(resnetvae, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
In [ ]:
# Generate more images
with torch.no_grad():
    for i in range(5):
        z = torch.randn(64, 512).to(device)
        sample = resnetvae.decoder(z)
    

recons = sample.view(64, 3, 128, 128).cpu()
In [ ]:
fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed

for i in range(10):
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].axis('off')  # Turn off axis labels for clarity

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)

with torch.no_grad():
    sample, _ = resnetvae(test_data)

    recons = sample.view(batch_size, 3, 128, 128).cpu()

    fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed

    for i in range(10):
        img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
        # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
        test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
        axes[0][i].imshow(test_img)
        axes[0][i].axis('off')
        axes[1][i].imshow(img)
        axes[1][i].axis('off')  # Turn off axis labels for clarity

    plt.tight_layout()
    plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Let's check the influence of the KL divergence lambda parameter by increasing it by the factor of 10

In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnetvae = ResNetVAE(512)
criterion = vae_loss_function
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 4, factor = 0.5, verbose = True)
resnetvae = resnetvae.to(device)
In [ ]:
train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
        model=resnetvae, optimizer=optimizer, scheduler=scheduler, criterion=vae_loss_function,
        train_loader=trainloader, valid_loader=validloader, num_epochs=num_epochs, writer=None,lambda_kld=1e-3)
Epoch 1 Iter 2544: loss 0.05426. : 100%|████████████████████████████████████████████| 2544/2544 [07:37<00:00,  5.56it/s]
    Train loss: 0.07213
    Valid loss: 2.92893
       Valid loss recons: 0.28448
       Valid loss KL-D:   2644.44034
Epoch 2 Iter 2544: loss 0.04735. : 100%|████████████████████████████████████████████| 2544/2544 [02:56<00:00, 14.38it/s]
Epoch 3 Iter 2544: loss 0.03960. : 100%|████████████████████████████████████████████| 2544/2544 [02:51<00:00, 14.81it/s]
    Train loss: 0.04312
    Valid loss: 0.0437
       Valid loss recons: 0.03463
       Valid loss KL-D:   9.06563
Epoch 4 Iter 2544: loss 0.03926. : 100%|████████████████████████████████████████████| 2544/2544 [01:58<00:00, 21.47it/s]
Epoch 5 Iter 2544: loss 0.04034. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
    Train loss: 0.04175
    Valid loss: 0.04172
       Valid loss recons: 0.03235
       Valid loss KL-D:   9.37758
Epoch 6 Iter 2544: loss 0.04283. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.35it/s]
Epoch 7 Iter 2544: loss 0.04417. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
    Train loss: 0.0406
    Valid loss: 0.04049
       Valid loss recons: 0.03125
       Valid loss KL-D:   9.23116
Epoch 8 Iter 2544: loss 0.04002. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
Epoch 9 Iter 2544: loss 0.04285. : 100%|████████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
    Train loss: 0.04022
    Valid loss: 0.04003
       Valid loss recons: 0.03053
       Valid loss KL-D:   9.50145
Epoch 10 Iter 2544: loss 0.03665. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s]
Epoch 11 Iter 2544: loss 0.04062. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
    Train loss: 0.03988
    Valid loss: 0.03968
       Valid loss recons: 0.03
       Valid loss KL-D:   9.68452
Epoch 12 Iter 2544: loss 0.03864. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
Epoch 13 Iter 2544: loss 0.04046. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s]
    Train loss: 0.0397
    Valid loss: 0.03939
       Valid loss recons: 0.03
       Valid loss KL-D:   9.38439
Epoch 14 Iter 2544: loss 0.03955. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.25it/s]
Epoch 15 Iter 2544: loss 0.03783. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
    Train loss: 0.03959
    Valid loss: 0.03923
       Valid loss recons: 0.02957
       Valid loss KL-D:   9.65668
Epoch 16 Iter 2544: loss 0.04080. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s]
Epoch 17 Iter 2544: loss 0.03856. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
    Train loss: 0.03946
    Valid loss: 0.03912
       Valid loss recons: 0.02938
       Valid loss KL-D:   9.73264
Epoch 18 Iter 2544: loss 0.03900. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.25it/s]
Epoch 19 Iter 2544: loss 0.06096. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
    Train loss: 0.03935
    Valid loss: 0.03908
       Valid loss recons: 0.02913
       Valid loss KL-D:   9.94453
Epoch 20 Iter 2544: loss 0.04351. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s]
Epoch 21 Iter 2544: loss 0.03699. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s]
    Train loss: 0.03922
    Valid loss: 0.03893
       Valid loss recons: 0.02917
       Valid loss KL-D:   9.75982
Epoch 22 Iter 2544: loss 0.04202. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.42it/s]
Epoch 23 Iter 2544: loss 0.04551. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.31it/s]
    Train loss: 0.0391
    Valid loss: 0.03891
       Valid loss recons: 0.02908
       Valid loss KL-D:   9.83609
Epoch 24 Iter 2544: loss 0.03520. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s]
Epoch 25 Iter 2544: loss 0.03931. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
    Train loss: 0.03894
    Valid loss: 0.03877
       Valid loss recons: 0.02879
       Valid loss KL-D:   9.98364
Epoch 26 Iter 2544: loss 0.03959. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.26it/s]
Epoch 27 Iter 2544: loss 0.04638. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.37it/s]
    Train loss: 0.03886
    Valid loss: 0.03864
       Valid loss recons: 0.02846
       Valid loss KL-D:   10.18017
Epoch 28 Iter 2544: loss 0.03527. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
Epoch 29 Iter 2544: loss 0.03818. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
    Train loss: 0.03883
    Valid loss: 0.0385
       Valid loss recons: 0.02849
       Valid loss KL-D:   10.00849
Epoch 30 Iter 2544: loss 0.03845. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.39it/s]
Epoch 31 Iter 2544: loss 0.03408. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
    Train loss: 0.03869
    Valid loss: 0.03851
       Valid loss recons: 0.02832
       Valid loss KL-D:   10.19438
Epoch 32 Iter 2544: loss 0.03660. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.33it/s]
Epoch 33 Iter 2544: loss 0.04607. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
    Train loss: 0.0386
    Valid loss: 0.03849
       Valid loss recons: 0.02827
       Valid loss KL-D:   10.22045
Epoch 34 Iter 2544: loss 0.04096. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.24it/s]
Epoch 35 Iter 2544: loss 0.03882. : 100%|███████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.48it/s]
    Train loss: 0.03854
    Valid loss: 0.03828
       Valid loss recons: 0.02809
       Valid loss KL-D:   10.19361
Epoch 36 Iter 2544: loss 0.03771. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.46it/s]
Epoch 37 Iter 2544: loss 0.03488. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
    Train loss: 0.03849
    Valid loss: 0.03835
       Valid loss recons: 0.02809
       Valid loss KL-D:   10.25628
Epoch 38 Iter 2544: loss 0.03541. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
Epoch 39 Iter 2544: loss 0.03638. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.27it/s]
    Train loss: 0.03845
    Valid loss: 0.03826
       Valid loss recons: 0.02804
       Valid loss KL-D:   10.21794
Epoch 40 Iter 2544: loss 0.03822. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.41it/s]
Epoch 41 Iter 2544: loss 0.04498. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.38it/s]
    Train loss: 0.03843
    Valid loss: 0.03815
       Valid loss recons: 0.02773
       Valid loss KL-D:   10.42156
Epoch 42 Iter 2544: loss 0.03487. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.45it/s]
Epoch 43 Iter 2544: loss 0.03993. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.36it/s]
    Train loss: 0.03841
    Valid loss: 0.03811
       Valid loss recons: 0.02763
       Valid loss KL-D:   10.48405
Epoch 44 Iter 2544: loss 0.04279. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.40it/s]
Epoch 45 Iter 2544: loss 0.03781. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.43it/s]
    Train loss: 0.03835
    Valid loss: 0.03827
       Valid loss recons: 0.02803
       Valid loss KL-D:   10.23958
Epoch 46 Iter 2544: loss 0.04364. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.28it/s]
Epoch 47 Iter 2544: loss 0.03749. : 100%|███████████████████████████████████████████| 2544/2544 [01:43<00:00, 24.49it/s]
    Train loss: 0.03834
    Valid loss: 0.03816
       Valid loss recons: 0.02794
       Valid loss KL-D:   10.22188
Epoch 48 Iter 2544: loss 0.03406. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.29it/s]
Epoch 49 Iter 2544: loss 0.04041. : 100%|███████████████████████████████████████████| 2544/2544 [01:44<00:00, 24.35it/s]
    Train loss: 0.03831
    Valid loss: 0.03822
       Valid loss recons: 0.02795
       Valid loss KL-D:   10.27694
Epoch 50 Iter 2544: loss 0.04041. : 100%|███████████████████████████████████████████| 2544/2544 [01:45<00:00, 24.19it/s]
    Train loss: 0.0383
    Valid loss: 0.03812
       Valid loss recons: 0.02782
       Valid loss KL-D:   10.30331
Training completed

In [ ]:
savepath = "/experiments/experiment_40120242"
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
In [ ]:
criterion = vae_loss_function
optimizer = torch.optim.Adam(resnetvae.parameters(), lr=1e-3)
In [ ]:
resnetvae = ResNetVAE().to(device)
resnetvae, optimizer, epoch, stats = load_model(resnetvae, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
In [ ]:
utils.plot_loss_epoch(stats['train_loss'][:49], stats['valid_loss'])

In both cases (KLD lambda/weight higher and lower), training loss value hardly changes except the first epoch. Becauseof that we used learning rate scheduler to decrease the learning rate after 4 epochs, when a significant change in validation loss wasn't noticed.

In [ ]:
# Generate more images
with torch.no_grad():
    for i in range(5):
        z = torch.randn(64, 512).to(device)
        sample = resnetvae.decoder(z)
    

recons = sample.view(64, 3, 128, 128).cpu()

fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed

for i in range(10):
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].axis('off')  # Turn off axis labels for clarity

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
In [ ]:
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)

with torch.no_grad():
    sample, _ = resnetvae(test_data)

    recons = sample.view(batch_size, 3, 128, 128).cpu()

    fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed

    for i in range(10):
        img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
        # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
        test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
        axes[0][i].imshow(test_img)
        axes[0][i].axis('off')
        axes[1][i].imshow(img)
        axes[1][i].axis('off')  # Turn off axis labels for clarity

    plt.tight_layout()
    plt.show()

Manipulating kl divergence lambda value can lead to vastly different results. When KL Divergence lambda value was equal to 1e-4, faces were more diverse than when the value was equal to 1e-3. A more encompassing discussion of this issue can be found in the appended thread: https://stats.stackexchange.com/questions/332179/how-to-weight-kld-loss-vs-reconstruction-loss-in-variational-auto-encoder. Basically, "higher values (of the KL divergence parameter) give a more structured latent space at the cost of poorer reconstruction, and lower values give better reconstruction with a less structured latent space" [1]. This makes sense and can be perfectly seen here. In the end, reconstructions created by the first model, were after all better, simply because of the better reconstruction. The thing we learned is, that a proper KL divergence weighting value must be chosen in order to achieve satisfying results.

Frechet Inception Distance¶

We use the implementation of torchmetrics to calculate the frechet inception distance.

In [ ]:
import torchmetrics
In [ ]:
from torchmetrics.image.fid import FrechetInceptionDistance

Frechet Inception Distance ConvVAE¶

In [ ]:
from torchmetrics.image.fid import FrechetInceptionDistance
fid = FrechetInceptionDistance(feature=64, normalize = True)
# Torchmetrics implementation does not use cuda for some reason
device = 'cpu'
cvae = ConvVAE()
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
cvae = cvae.to(device)

test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                       download=True, transform=transform)

testloader = torch.utils.data.DataLoader(test_dataset, batch_size = 64)

progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    z = torch.randn(64, 200).to(device)
    sample = cvae.decoder(z)
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
Files already downloaded and verified
100%|██████████| 312/312 [13:08<00:00,  2.53s/it]
Out[ ]:
tensor(5.0324)
In [ ]:
print("FID for Vanilla ConvVAE (without pretrained encoder) is equal to 5.0324")
FID for Vanilla ConvVAE (without pretrained encoder) is equal to 5.0324

Frechet Inception Distance ResNetVAE¶

FID for ResNetVAE with lower KL Divergence value.¶

In [ ]:
fid = FrechetInceptionDistance(feature=64, normalize = True)
device = 'cpu'
In [ ]:
resnetvae_KL_lower = ResNetVAE(512).to(device)
optimizer = torch.optim.Adam(resnetvae_KL_lower.parameters(), lr=1e-3)
resnetvae_KL_lower, optimizer, epoch, stats = load_model(resnetvae_KL_lower, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
In [ ]:
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    z = torch.randn(64, 512).to(device)
    sample = resnetvae_KL_lower.decoder(z)
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [12:35<00:00,  2.42s/it]
Out[ ]:
tensor(4.4890)
In [ ]:
print("FID for ResNetVAE with KLD_lambda = 1e-4 is equal to 4.4890")
FID for ResNetVAE with KLD_lambda = 1e-4 is equal to 4.4890

FID for ResNetVAE with higher KL Divergence value.¶

In [ ]:
fid = FrechetInceptionDistance(feature=64, normalize = True)
In [ ]:
resnetvae_KL_higher = ResNetVAE(512).to(device)
optimizer = torch.optim.Adam(resnetvae_KL_higher.parameters(), lr=1e-3)
resnetvae_KL_higher, optimizer, epoch, stats = load_model(resnetvae_KL_higher, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
In [ ]:
progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    z = torch.randn(64, 512).to(device)
    sample = resnetvae_KL_higher.decoder(z)
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 312/312 [12:44<00:00,  2.45s/it]
Out[ ]:
tensor(5.8770)
In [ ]:
print("FID for ResNetVAE with KLD_lambda = 1e-3 is equal to 5.8770")
FID for ResNetVAE with KLD_lambda = 1e-3 is equal to 5.8770

Discussion¶

FID is a metric for quantifying the realism and diversity of generated images. As we expected, the lowest FID value was achieved by ResNetVAE with a lower KLD weight. What is interesting, is that ConvVAE achieved better results that ResNetVAE with a higher KLD weight. In the end, we can empirically confirm those numbers just by looking at the generated images. The best looking ones are generated by ResNetVAE model with lower KLD weight and having that weight be higher, diminished the quality of the images. Nevertheless, we consider both ResNetVAE models to have higher quality generation than ConvVAE. Both ResNetVAE faces look less blocky and more natural than the ones of our ConvVAE. This holds especially for the lower KLD value model, even though the backgrounds are less natural and the faces are sometimes cut off. The reconstruction quality is similar. The ranking stays the same here. Especially the ResNet model with the lower KLD value shows more detail and quality for the reasons discussed above.

Analysis of Latent Space: ConvVAE & ResNetVAE¶

ConvVAE¶

In [ ]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
In [ ]:
from sklearn.metrics import jaccard_score
import matplotlib.pyplot as plt
import numpy as np
In [ ]:
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import linkage, fcluster
In [ ]:
def jaccard_similarity(x, y):
    return np.sum(np.minimum(x, y)) / np.sum(np.maximum(x, y))

def assign_clusters(labels, n_clusters):
    # Compute the pairwise distance matrix in the Jaccard metric
    dists = pdist(labels, metric=jaccard_similarity)
    
    # Compute the linkage matrix
    linkage_matrix = linkage(squareform(dists), method='average')
    
    # Assign clusters
    clusters = fcluster(linkage_matrix, n_clusters, criterion='maxclust')
    
    return clusters
In [ ]:
def display_projections(points, labels, num_classes, ax):
    # Ensure labels are a numpy array for convenience
    labels = np.array(labels)
    
    # Create a colormap
    cmap = plt.get_cmap('RdYlGn', num_classes)  # 20 classes

    for i in range(num_classes):  # For each class
        # Find points belonging to this class
        class_points = points[labels == i]
        
        # Plot these points with the color corresponding to this class
        ax.scatter(class_points[:, 0], class_points[:, 1], color=cmap(i), label=f'Group {i+1}')

    # Add a colorbar
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap), ax=ax)
In [ ]:
N = 2000
num_clusters = 40
In [ ]:
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
    for imgs, lbls in testloader:
        imgs = imgs.to(device)
        _, (z, _, _) = cvae(imgs)
        imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
        latents.append(z.cpu())
        labels.append(lbls)
        
imgs_flat = np.concatenate(imgs_flat)    
latents = np.concatenate(latents)
labels = np.concatenate(labels)
In [ ]:
t, l = next(iter(testloader))
In [ ]:
l.shape
Out[ ]:
torch.Size([64, 40])
In [ ]:
clusters = assign_clusters(labels[:N], num_clusters)
In [ ]:
pca_imgs
Out[ ]:
array([[-25.339378 , -12.632147 ],
       [-48.634953 , -25.977364 ],
       [-28.15862  ,   2.8991623],
       ...,
       [ 42.41321  ,  23.036709 ],
       [ -6.5465856,  17.496992 ],
       [-46.728096 ,  -2.5804029]], dtype=float32)
In [ ]:
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()
In [ ]:
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])

fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()

ResNetVAE¶

KL Divergence weight lower.¶

In [ ]:
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
    for imgs, lbls in testloader:
        imgs = imgs.to(device)
        _, (z, _, _) = resnetvae_KL_lower(imgs)
        imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
        latents.append(z.cpu())
        labels.append(lbls)
        
imgs_flat = np.concatenate(imgs_flat)    
latents = np.concatenate(latents)
labels = np.concatenate(labels)
In [ ]:
def jaccard_similarity(x, y):
    return np.sum(np.minimum(x, y)) / np.sum(np.maximum(x, y))

def assign_clusters(labels, n_clusters):
    # Compute the pairwise distance matrix in the Jaccard metric
    dists = pdist(labels, metric=jaccard_similarity)
    
    # Compute the linkage matrix
    linkage_matrix = linkage(squareform(dists), method='average')
    
    # Assign clusters
    clusters = fcluster(linkage_matrix, n_clusters, criterion='maxclust')
    
    return clusters
In [ ]:
N = 2000
num_clusters = 40
In [ ]:
clusters = assign_clusters(labels[:N], num_clusters)
In [ ]:
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
In [ ]:
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
#display_projections(pca_imgs[:N ], labels[:N], ax=ax[0]) 
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()
In [ ]:
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])

fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()

KL Divergence weight higher.¶

In [ ]:
imgs_flat, latents, labels = [], [], []
with torch.no_grad():
    for imgs, lbls in testloader:
        imgs = imgs.to(device)
        _, (z, _, _) = resnetvae_KL_higher(imgs)
        imgs_flat.append(imgs.cpu().view(imgs.shape[0],-1))
        latents.append(z.cpu())
        labels.append(lbls)
        
imgs_flat = np.concatenate(imgs_flat)    
latents = np.concatenate(latents)
labels = np.concatenate(labels)
In [ ]:
clusters = assign_clusters(labels[:N], num_clusters) # repeat the steps
In [ ]:
pca_imgs = PCA(n_components=2).fit_transform(imgs_flat[:N])
pca_latents = PCA(n_components=2).fit_transform(latents[:N])
fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(pca_imgs[:N], clusters,num_classes = num_clusters, ax = ax[0])
ax[0].set_title("PCA Proj. of Images")
display_projections(pca_latents[:N], clusters, num_classes = num_clusters, ax=ax[1])
ax[1].set_title("Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()
In [ ]:
N = 2000
tsne_imgs = TSNE(n_components=2).fit_transform(imgs_flat[:N])
tsne_latents = TSNE(n_components=2).fit_transform(latents[:N])

fig,ax = plt.subplots(1,2,figsize=(26,8))
display_projections(tsne_imgs[:N], clusters[:N], num_classes = num_clusters, ax=ax[0])
ax[0].set_title("T-SNE Proj. of Images")
display_projections(tsne_latents[:N], clusters[:N],num_classes = num_clusters, ax=ax[1])
ax[1].set_title("T-SNE Proj. of Encoded Representations")
ax[1].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
          fancybox=True, shadow=True, ncol=5)
plt.show()

Discussion¶

Given the large number of identities (labels), we tried to cluster the images ourselves.

In our problem context, we are dealing with 40-dimensional characteristic vectors, for instance, [1, 0, 0, 1, 0 ,1, …]. Each index in this vector represents a specific attribute, such as whether the individual is smiling (first index), or if the individual has long hair (second index), and so on.

The Jaccard similarity is particularly effective in this scenario because it calculates the similarity or diversity between two sets. This is done by comparing the intersection and the union of the sets, as illustrated in the attached diagram. Consequently, individuals who share a greater number of characteristics will have a higher similarity score. This makes the Jaccard similarity a powerful tool for our problem.

That's why we used Jaccard similarity to form clusters i.e. points, which attribute vectors more similar to each other in jaccard similarity, were put in one group/cluster/class.

Firstly, we observe that while it seems natural to us to cluster faces on striking features, the model considers more information and thus considers different images to be close together. This is intuitive, considering how large of a portion the background of an image usually has. Secondly, we observe that this disentanglement of the clusters transfers to the latent space. For similar reasons this is intuitive. If a general disentanglement of clusters in the projection moving to the latent space occurs requires more reserarch and more precise clustering.

Disregarding clusters, one can observe that a general structure of the latent space is hard to discern. Thus also making analysis difficult. We do observe however that our latent space mirrors the shape of a normal distribution, which makes sense given we impose this assumption onto our model.

Interpolation: ConvVAE¶

In [ ]:
device = 'cpu'
In [ ]:
cvae = ConvVAE().to(device)
optimizer = torch.optim.Adam(cvae.parameters(), lr=3e-4)
cvae, optimizer, epoch, stats = load_model(cvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_8/models/checkpoint_epoch_48.pth')
In [ ]:
resnetvae_KL_lower = ResNetVAE().to(device)
optimizer = torch.optim.Adam(resnetvae_KL_lower.parameters(), lr=1e-3)
resnetvae_KL_lower, optimizer, epoch, stats = load_model(resnetvae_KL_lower, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120241/models/checkpoint_epoch_48.pth')
In [ ]:
resnetvae_KL_higher = ResNetVAE().to(device)
optimizer = torch.optim.Adam(resnetvae_KL_higher.parameters(), lr=1e-3)
resnetvae_KL_higher, optimizer, epoch, stats = load_model(resnetvae_KL_higher, optimizer, '/home/user/rogf1/CudaVisionWS23/Assignment5/experiments/experiment_40120242/models/checkpoint_epoch_48.pth')
In [ ]:
@torch.no_grad()
def plot_reconstructed(model, test_imgs, N=25):
    # pdb.set_trace()
    model = model.eval()
    _,(z,_, _) = model(test_imgs)
    z1, z2, z3, z4 = z
    # checkpoints = [0,0.25,0.5,0.75,1]
    checkpoints = [0,0.125,0.25,0.375,0.5,0.625,0.75,0.875,1]
    interpolated_latent_vectors_LHS = [lam * z3 + (1-lam) * z2 for lam in checkpoints]
    interpolated_latent_vectors_RHS = [lam * z4 + (1-lam) * z1 for lam in checkpoints]
    for LHS,RHS in zip(interpolated_latent_vectors_LHS, interpolated_latent_vectors_RHS):
        interpolated_row = [lam * LHS + (1-lam) * RHS for lam in checkpoints]

        fig, ax = plt.subplots(1, len(checkpoints), figsize=(27,3), sharey=True, sharex=True)
        for col,img in enumerate(interpolated_row):
            ax[col].axis('off')
            recon = model.decoder(img.view(1, img.size(0)))
            recon = recon.view(3, 128, 128).cpu()
            img = recon.detach().numpy().reshape(3, 128, 128).transpose(1, 2, 0)
            maxValue = np.amax(img)
            minValue = np.amin(img)
            img = np.clip(img, 0, 1)
            ax[col].imshow(img)

        fig.tight_layout(pad=0, h_pad=0)
In [ ]:
plot_reconstructed(resnetvae_KL_lower, test_data[0:4])
In [ ]:
plot_reconstructed(resnetvae_KL_higher, test_data[0:4])
In [ ]:
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)
plot_reconstructed(cvae, test_data[0:4])

Discussion¶

Most notable face interpolation is produced by ResNetVAE model (lower KLD weight). There, change of faces can actually be seen. Also, the quality of reconstructed images is the best there. Second place belongs to ConvVAE, which interpolation consists of faces that change only a little bit, but one could argue that it still looks better that ResNetVAE with higher KLD weight, where it can be clearly seen that the there's hardly any interpolation, because the face remains the same in all the interpolated images.

Bonus Task: VQVAE¶

In [ ]:
def save_model(model, optimizer, epoch, stats, exp_no = 7):
    """ Saving model checkpoint """
    
    if(not os.path.exists("experiments/experiment_"+str(exp_no)+"/models")):
        os.makedirs("experiments/experiment_"+str(exp_no)+"/models")
    savepath = "experiments/experiment_"+str(exp_no)+f"/models/checkpoint_epoch_{epoch}.pth"

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats
    }, savepath)
    
    return


def load_model(model, optimizer, savepath):
    """ Loading pretrained checkpoint """
    
    checkpoint = torch.load(savepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint["epoch"]
    stats = checkpoint["stats"]
    
    return model, optimizer, epoch, stats
In [ ]:
def train_epoch(model, train_loader, optimizer, criterion, epoch, device):
    """ Training a model for one epoch """
    
    loss_list = []
    recons_loss = []
    vae_loss = []
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (images, _) in progress_bar:
        images = images.to(device)
        
        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()
         
        # Forward pass
        recons, quant_loss = model(images)
         
        # Calculate Loss
        loss, mse = criterion(recons, quant_loss, images)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())

        # Getting gradients w.r.t. parameters
        loss.backward()
         
        # Updating parameters
        optimizer.step()
        
        progress_bar.set_description(f"Epoch {epoch+1} Iter {i+1}: loss {loss.item():.5f}. ")
        
    mean_loss = np.mean(loss_list)
    
    return mean_loss, loss_list


@torch.no_grad()
def eval_model(model, eval_loader, criterion, device, epoch=None, savefig=False, savepath="", writer=None):
    """ Evaluating the model for either validation or test """
    loss_list = []
    recons_loss = []
    
    for i, (images, _) in enumerate(eval_loader):
        images = images.to(device)
        
        # Forward pass 
        recons, quant_loss = model(images)
                 
        loss, mse = criterion(recons, quant_loss, images)
        loss_list.append(loss.item())
        recons_loss.append(mse.item())
        
        if(i==0 and savefig):
            save_image(recons[:64].cpu(), os.path.join(savepath, f"recons{epoch}.png"))
            
    # Total correct predictions and loss
    loss = np.mean(loss_list)
    recons_loss = np.mean(recons_loss)

    return loss, recons_loss


def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader,
                num_epochs, savepath, writer, save_frequency=2):
    """ Training a model for a given number of epochs"""
    
    train_loss = []
    val_loss =  []
    val_loss_recons =  []
    val_loss_kld =  []
    loss_iters = []
    
    for epoch in range(num_epochs):
           
        # validation epoch
        model.eval()  # important for dropout and batch norms
        log_epoch = (epoch % save_frequency == 0 or epoch == num_epochs - 1)
        loss, recons_loss= eval_model(
                model=model, eval_loader=valid_loader, criterion=criterion,
                device=device, epoch=epoch, savefig=log_epoch, savepath=savepath,
                writer=writer
            )
        val_loss.append(loss)
        val_loss_recons.append(recons_loss)
        
        # training epoch
        model.train()  # important for dropout and batch norms
        mean_loss, cur_loss_iters = train_epoch(
                model=model, train_loader=train_loader, optimizer=optimizer,
                criterion=criterion, epoch=epoch, device=device
            )
        
        # PLATEAU SCHEDULER
        scheduler.step(val_loss[-1])
        train_loss.append(mean_loss)
        loss_iters = loss_iters + cur_loss_iters
        
        if(epoch % save_frequency == 0):
            stats = {
                "train_loss": train_loss,
                "valid_loss": val_loss,
                "loss_iters": loss_iters
            }
            save_model(model=model, optimizer=optimizer, epoch=epoch, stats=stats)
        
        if(log_epoch):
            print(f"    Train loss: {round(mean_loss, 5)}")
            print(f"    Valid loss: {round(loss, 5)}")
            print(f"       Valid loss recons: {round(val_loss_recons[-1], 5)}")
    
    print(f"Training completed")
    return train_loss, val_loss, loss_iters, val_loss_recons
In [ ]:
# Code was taken from: https://colab.research.google.com/github/zalandoresearch/pytorch-vq-vae/blob/master/vq-vae.ipynb
# Note, that the commitment_cost is the beta parameter from the paper

class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        print("input Shape", inputs.shape)
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        print("Encodings Shape", encodings.shape)
        encodings.scatter_(1, encoding_indices, 1)
        print(encodings.shape)
        print("encodings are one hots", encodings[1].sum())
        
        # Quantize and unflatten
        print("quantized shape", torch.matmul(encodings, self._embedding.weight).shape)
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
In [ ]:
class VQVAE(nn.Module):
    def __init__(self):
        super(VQVAE, self).__init__()
        
        # Define Convolutional Encoders
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels = 3, out_channels = 16, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 16, out_channels = 32, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(),
            nn.Conv2d(in_channels = 32, out_channels = 64, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(),
        )

        self.VQ = VectorQuantizer(embedding_dim=64, num_embeddings=512, commitment_cost=0.25)

        # Define decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels = 64, out_channels = 32, kernel_size = 4, stride = 2, padding = 1),
            nn.ReLU(), 
            nn.ConvTranspose2d(in_channels = 32, out_channels = 16, kernel_size = 4, stride = 2, padding = 1),
            nn.Tanh(),
            nn.ConvTranspose2d(in_channels = 16, out_channels = 3, kernel_size = 4, stride = 2, padding = 1),
        )
    
    def forward(self, x):
        x = self.encoder(x)
        q_loss, x, ppl, encodings = self.VQ(x)
        out = self.decoder(x) 

        return out, q_loss

def vqvae_loss_function(recons, quantize_losses, target):
    recons_loss = F.mse_loss(recons, target)
    
    loss = recons_loss + quantize_losses

    return loss, quantize_losses
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vqvae = VQVAE()
criterion = vqvae_loss_function
optimizer = torch.optim.Adam(vqvae.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience = 3, factor = 0.5, verbose = True)
vqvae = vqvae.to(device)
In [ ]:
vqvae, optimizer, epoch, stats = load_model(vqvae, optimizer, '/home/user/lschulze/projects/CudaVisionWS23/Assignment5/experiments/experiment_7/models/checkpoint_epoch_18.pth')

Reconstruction of images¶

In [ ]:
test_data, labels = next(iter(testloader))
test_data = test_data.to(device)

with torch.no_grad():
    sample, _ = vqvae(test_data)

    recons = sample.view(batch_size, 3, 128, 128).cpu()

    fig, axes = plt.subplots(2, 10, figsize=(128, 128))  # Adjust figsize as needed

    for i in range(10):
        img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
        # test_img = test_data[i].reshape(3, 128, 128).transpose(1, 2, 0)
        test_img = test_data[i].cpu().numpy().reshape(3,128,128).transpose(1, 2, 0)
        axes[0][i].imshow(test_img)
        axes[0][i].axis('off')
        axes[1][i].imshow(img)
        axes[1][i].axis('off')  # Turn off axis labels for clarity

    plt.tight_layout()
    plt.show()
/home/user/lschulze/anaconda3/envs/lab/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
  warnings.warn(
/home/user/lschulze/anaconda3/envs/lab/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).
  warnings.warn(
input Shape torch.Size([64, 64, 16, 16])
Encodings Shape torch.Size([16384, 512])
torch.Size([16384, 512])
encodings are one hots tensor(1., device='cuda:0')
quantized shape torch.Size([16384, 64])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

The reconstructions of our VQVAE look very nice and very detailed. Comparing the reconstructions to the ones above those rank clearly the highest, showing the highest amount of detail. Unfortunately, we have some visual artifacts in our reconstructed images. However, given our prior findings, we are very confident that a more complex encoder/decoder would completely resolve the issue.

Generation of images using uniform prior¶

In [ ]:
def choose_rows_uniformly(matrix, num_rows_to_choose):
    # Get the total number of rows in the matrix
    total_rows = matrix.size(0)

    # Generate random indices for selecting rows
    random_indices = torch.randint(0, total_rows, (num_rows_to_choose,), dtype=torch.long)

    # Use the random indices to select rows from the matrix
    selected_rows = matrix[random_indices]

    return selected_rows

# Example usage:
# Assuming 'your_matrix' is a PyTorch tensor with shape (num_rows, num_columns)
your_matrix = vqvae.VQ._embedding.weight
num_rows_to_choose = 16384

selected_rows = choose_rows_uniformly(your_matrix, num_rows_to_choose)
print("Selected Rows:")
print(selected_rows.shape)
Selected Rows:
torch.Size([16384, 64])
In [ ]:
with torch.no_grad():
    z = selected_rows.reshape(64, 64, 16, 16).to(device)
    # z = torch.rand(64, 64, 16, 16).to(device)
    # z = vqvae.encoder(z)
    # _, z, _ , _ = vqvae.VQ(z)
    sample = vqvae.decoder(z)

recons = sample.view(64, 3, 128, 128).cpu()
torch.Size([64, 3, 128, 128])
Out[ ]:
torch.Size([64, 3, 128, 128])
In [ ]:
fig, axes = plt.subplots(1, 10, figsize=(128, 128))  # Adjust figsize as needed

for i in range(10):
    img = recons[i].numpy().reshape(3, 128, 128).transpose(1, 2, 0)
    axes[i].imshow(img)
    axes[i].axis('off')  # Turn off axis labels for clarity

plt.tight_layout()
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

The generation of images using VQVAE was not successful unfortunately, clearly ranking last when compared to the above results. We consider the reason for this to be uniform prior being too weak of an assumption. Considering, that during reconstruction our encoded vectors have a very particular structure based on the input, this seems intuitive. Using a probability distribution given by an RNN or LSTM, trained on existent latent representations, would therefore be the better choice.

Frechet Inception Distance: VQVAE¶

In [ ]:
from torchmetrics.image.fid import FrechetInceptionDistance
fid = FrechetInceptionDistance(feature=64, normalize = True)
# Torchmetrics implementation does not use cuda for some reason
device = 'cpu'
vqvae = vqvae.to(device)

test_dataset = torchvision.datasets.CelebA(root='./data', split='test',
                                       download=True, transform=transform)

testloader = torch.utils.data.DataLoader(test_dataset, batch_size = 64)

progress_bar = tqdm(enumerate(testloader), total=len(testloader))
for i, (test_data, _) in progress_bar:
    test_data = test_data.to(device)
    selected_rows = choose_rows_uniformly(your_matrix, num_rows_to_choose)
    z = selected_rows.reshape(64, 64, 16, 16).to(device)
    sample = vqvae.decoder(z)
    fid.update(sample, real=False)
    fid.update(test_data, real=True)
fid.compute()
Files already downloaded and verified
100%|██████████| 312/312 [08:57<00:00,  1.72s/it]
Out[ ]:
tensor(85.4611)

The FID score of the VQVAE results are the highest compared to all prior models. This is expected given the viewed reconstructions